{% comment %}
// vim: set syntax=asm :
/* mmm 32 x 3:

    ymm0 ymm4 ymm8
    ymm1 ymm5 ymm9
    ymm2 ymm6 ymm10
    ymm3 ymm7 ymm11

System V ABI:
    args: rdi, rsi, rdx, rcx, r8, r9
    preserve: rbx, rsp, rbp, r12, r13, r14, r15
    scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11
    return: rax (+rdx)

Windows ABI:
    args: RCX, RDX, R8, R9
    preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15
    scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15
    return: rax (+rdx)
*/
{% endcomment %}

{% include "preamble.tmpliq" type:"f32", size:"32x3", suffix:suffix, G:G %}

{{L}}clear:
    // Zero every ymm register in one instruction — in particular the whole
    // 32x3 accumulator tile (ymm0-ymm11).
    vzeroall
    jmp     {{L}}non_linear_loop

{{L}}add_mat_mul:
    mov     rbx,    [rdi + 8]    // k
    mov     rcx,    [rdi + 24]   // B
    mov     rax,    [rdi + 16]   // A

    mov     r8,    [rdi + 32]   // packing

    // k == 0: nothing to accumulate.
    test    rbx,    rbx
    jz      {{L}}non_linear_loop

    // Dispatch on the packing id. Fall-through (id 0) is f32 A x f32 B;
    // the other loops handle f16 on one or both operands (see the pointer
    // increments in each loop for the per-iteration element sizes).
    cmp     r8, 1
    jz      {{L}}main_loop_packed_packed_f32_f16

    cmp     r8, 2
    jz      {{L}}main_loop_packed_packed_f16_f32

    cmp     r8, 3
    jz      {{L}}main_loop_packed_packed_f16_f16

// f32 A x f32 B: one rank-1 update of the 32x3 tile per k iteration.
{{L}}main_loop_packed_packed:
    {% include "4x3/packed_packed_loop1/avx.tmpli" %}

    dec             rbx
    jnz             {{L}}main_loop_packed_packed

    jmp             {{L}}non_linear_loop

// f32 A x f16 B: an A column is 32 f32 (four 32-byte ymm loads, 128 bytes per
// iteration); each f16 B scalar is broadcast as a 16-bit word then widened to
// 8 x f32 with vcvtph2ps before feeding the FMAs.
{{L}}main_loop_packed_packed_f32_f16:
	// Load col of A
	vmovaps			ymm12,	[rax]

	// Fill 3 cols of B
    vpbroadcastw    xmm13,  word ptr [rcx + 0]
    vpbroadcastw    xmm14,  word ptr [rcx + 2]
    vpbroadcastw    xmm15,  word ptr [rcx + 4]

    vcvtph2ps       ymm13, xmm13
    vcvtph2ps       ymm14, xmm14
    vcvtph2ps       ymm15, xmm15

	// N.B. Stepping cols in inner loop
	vfmadd231ps		ymm0,	ymm12, ymm13
	vfmadd231ps		ymm4,	ymm12, ymm14
	vfmadd231ps		ymm8,	ymm12, ymm15

	vmovaps			ymm12,	[rax+32]

	vfmadd231ps		ymm1,	ymm12, ymm13
	vfmadd231ps		ymm5,	ymm12, ymm14
	vfmadd231ps		ymm9,	ymm12, ymm15

	vmovaps			ymm12,	[rax+64]

	vfmadd231ps		ymm2,	ymm12, ymm13
	vfmadd231ps		ymm6,	ymm12, ymm14
	vfmadd231ps		ymm10,	 ymm12, ymm15

	vmovaps			ymm12,	[rax+96]

	vfmadd231ps		ymm3,	ymm12, ymm13
	vfmadd231ps		ymm7,	ymm12, ymm14
	vfmadd231ps		ymm11,	ymm12, ymm15

    // Advance one k step: 3 f16 B scalars (6 bytes), 32 f32 A values (128 bytes).
    add             rcx,    6
    add             rax,    128

    dec             rbx
    jnz             {{L}}main_loop_packed_packed_f32_f16

    jmp             {{L}}non_linear_loop
    
// f16 A x f32 B: an A column is 32 f16 (four 16-byte xmm loads, 64 bytes per
// iteration), each chunk widened to 8 x f32 with vcvtph2ps; B scalars are
// plain f32 broadcasts.
{{L}}main_loop_packed_packed_f16_f32:
	// Load col of A
	vmovaps			xmm12,	[rax]

	// Fill 3 cols of B
    vbroadcastss    ymm13, dword ptr [rcx + 0]
    vbroadcastss    ymm14, dword ptr [rcx + 4]
    vbroadcastss    ymm15, dword ptr [rcx + 8]

    vcvtph2ps       ymm12, xmm12

	vfmadd231ps		ymm0,	ymm12, ymm13
	vfmadd231ps		ymm4,	ymm12, ymm14
	vfmadd231ps		ymm8,	ymm12, ymm15

	vmovaps			xmm12,	[rax+16]
    vcvtph2ps       ymm12, xmm12

	vfmadd231ps		ymm1,	ymm12, ymm13
	vfmadd231ps		ymm5,	ymm12, ymm14
	vfmadd231ps		ymm9,	ymm12, ymm15

	vmovaps			xmm12,	[rax+32]
    vcvtph2ps       ymm12, xmm12

	vfmadd231ps		ymm2,	ymm12, ymm13
	vfmadd231ps		ymm6,	ymm12, ymm14
	vfmadd231ps		ymm10,	 ymm12, ymm15

	vmovaps			xmm12,	[rax+48]
    vcvtph2ps       ymm12, xmm12

	vfmadd231ps		ymm3,	ymm12, ymm13
	vfmadd231ps		ymm7,	ymm12, ymm14
	vfmadd231ps		ymm11,	ymm12, ymm15

    // Advance one k step: 3 f32 B scalars (12 bytes), 32 f16 A values (64 bytes).
    add             rcx,    12
    add             rax,    64

    dec             rbx
    jnz             {{L}}main_loop_packed_packed_f16_f32

    jmp             {{L}}non_linear_loop

// f16 A x f16 B: both operands stored as f16 and widened to f32 on load —
// A in 16-byte chunks (64 bytes per iteration), B via 16-bit broadcasts.
{{L}}main_loop_packed_packed_f16_f16:
	// Load col of A
	vmovaps			xmm12,	[rax]

	// Fill 3 cols of B
    vpbroadcastw    xmm13,  word ptr [rcx + 0]
    vpbroadcastw    xmm14,  word ptr [rcx + 2]
    vpbroadcastw    xmm15,  word ptr [rcx + 4]

    vcvtph2ps       ymm12, xmm12
    vcvtph2ps       ymm13, xmm13
    vcvtph2ps       ymm14, xmm14
    vcvtph2ps       ymm15, xmm15

	vfmadd231ps		ymm0,	ymm12, ymm13
	vfmadd231ps		ymm4,	ymm12, ymm14
	vfmadd231ps		ymm8,	ymm12, ymm15

	vmovaps			xmm12,	[rax+16]
    vcvtph2ps       ymm12, xmm12

	vfmadd231ps		ymm1,	ymm12, ymm13
	vfmadd231ps		ymm5,	ymm12, ymm14
	vfmadd231ps		ymm9,	ymm12, ymm15

	vmovaps			xmm12,	[rax+32]
    vcvtph2ps       ymm12, xmm12

	vfmadd231ps		ymm2,	ymm12, ymm13
	vfmadd231ps		ymm6,	ymm12, ymm14
	vfmadd231ps		ymm10,	 ymm12, ymm15

	vmovaps			xmm12,	[rax+48]
    vcvtph2ps       ymm12, xmm12

	vfmadd231ps		ymm3,	ymm12, ymm13
	vfmadd231ps		ymm7,	ymm12, ymm14
	vfmadd231ps		ymm11,	ymm12, ymm15

    // Advance one k step: 3 f16 B scalars (6 bytes), 32 f16 A values (64 bytes).
    add             rcx,    6
    add             rax,    64

    dec             rbx
    jnz             {{L}}main_loop_packed_packed_f16_f16

    jmp             {{L}}non_linear_loop

// NON LINEAR / ADDC

{% include "fma_mmm_f32_scalars.tmpliq" from:0, to:11, type:"f32" %}
{% include "fma_mmm_f32_per_rows.tmpliq" mr:32, from:0, to:11, type:"f32" %}
{% include "fma_mmm_f32_per_cols.tmpliq" mr:32, from:0, to:11, type:"f32" %}
{% include "fma_mmm_load_tile.tmpliq" from:0, to:11 %}

{{L}}add_unicast:
    // C += values read from a strided f32 matrix.
    mov     r8,    [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride
    mov     rbx,    [rdi + 24]          // col stride

    // Fast path requires contiguous f32 rows (row stride == 4 bytes);
    // anything else goes through the gather-based generic path.
    cmp     rsi, 4
    jne     {{L}}unicast_generic

    // Column base pointers: r8/r9/r10 -> cols 0/1/2. (Only three columns in
    // this kernel — the former r11 pointer was dead and has been dropped.)
    lea             r9,  [ r8 + rbx ]
    lea             r10, [ r9 + rbx]

{% for col in (0..2) %}
    {% for row in (0..3) %}
        vmovups     ymm12, [ r{{col|plus:8}} ]
        add         r{{col|plus:8}}, 32
        vaddps      ymm{{col|times:4|plus:row}}, ymm{{col|times:4|plus:row}}, ymm12
    {% endfor %}
{% endfor %}

    jmp    {{L}}non_linear_loop

{{L}}unicast_generic:
    // Build per-lane byte offsets {0, rsi, 2*rsi, ..., 7*rsi} for vgatherdps:
    // xmm14 receives lanes 0..3, xmm15 lanes 4..7; merged into ymm14 below.
    mov     eax,    0
{% for i in (0..3) %}
    pinsrd  xmm14, eax, {{i}}
    add     eax,    esi
{% endfor %}
{% for i in (0..3) %}
    pinsrd  xmm15, eax, {{i}}
    add     eax,    esi
{% endfor %}

    vperm2f128      ymm14,  ymm14, ymm15,         32 // ymm14 <- xmm14::xmm15

    // One base pointer per 8-row block: r8..r11 cover rows 0-7, 8-15, 16-23, 24-31.
    lea             r9,  [ r8 + rsi * 8 ]
    lea             r10, [ r9 + rsi * 8 ]
    lea             r11, [ r10 + rsi * 8 ]

{% for col in (0..2) %}
   {% for row in (0..3) %}
      // vgatherdps consumes (zeroes) the mask register, so rebuild the
      // all-ones mask before every gather.
      vpcmpeqd      ymm15,  ymm15, ymm15
      vgatherdps    ymm12,  [ r{{row|plus:8}} + ymm14 ], ymm15
      add           r{{row|plus:8}}, rbx
      vaddps        ymm{{col|times:4|plus:row}}, ymm{{col|times:4|plus:row}}, ymm12
   {% endfor %}
{% endfor %}

    jmp    {{L}}non_linear_loop


{{L}}add_row_col_products:
    // Rank-1 update: C += row_vector (rax, 32 x f32) * col_vector (rbx, 3 x f32).
    mov             rax, [ rdi + 8 ]
    mov             rbx, [ rdi + 16 ]

    // Broadcast the three column coefficients, one per accumulator column.
    vbroadcastss    ymm13, dword ptr [rbx]
    vbroadcastss    ymm14, dword ptr [rbx + 4]
    vbroadcastss    ymm15, dword ptr [rbx + 8]
{% for i in (0..3) %}
    vmovups         ymm12,  [rax + {{i|times:32}}]
    vfmadd231ps     ymm{{0|plus:i}}, ymm12, ymm13
    vfmadd231ps     ymm{{4|plus:i}}, ymm12, ymm14
    vfmadd231ps     ymm{{8|plus:i}}, ymm12, ymm15
{% endfor %}
    jmp    {{L}}non_linear_loop

{{L}}store:
    mov     r8,     [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride
    mov     rbx,    [rdi + 24]          // col stride
    mov     r11,    [rdi + 32]          // item size

    // Column base pointers: r8/r9/r10 -> cols 0/1/2.
    lea     r9,     [ r8 + rbx ]
    lea     r10,    [ r8 + 2 * rbx ]

    // 2-byte items: narrow and store as f16.
    cmp     r11, 2
    je      {{L}}store_f16

    // f32 fast path requires contiguous rows (row stride == 4 bytes).
    cmp         rsi, 4
    jne         {{L}}store_strides_generic

    {% for col in (0..2) %}
        {% for row in (0..3) %}
            vmovups ymmword ptr [r{{col|plus:8}}], ymm{{col|times:4|plus:row}}
            add     r{{col|plus:8}}, 32
       {% endfor %}
    {% endfor %}

    jmp     {{L}}non_linear_loop

{{L}}store_strides_generic:

    // Scalar f32 store, stepping each column by the row stride (rsi):
    // extract the low 4 lanes, swap the 128-bit halves with vperm2f128
    // (imm 1 = lane swap), then extract the remaining 4 lanes. This leaves
    // the accumulator half-swapped, which is fine — values are dead after store.
    {% for col in (0..2) %}
       {% for row in (0..3) %}
           {% for i in (0..3) %}
                vextractps  dword ptr [r{{col | plus: 8}}], xmm{{col | times:4 | plus:row}}, {{i}}
                add         r{{col | plus: 8}}, rsi
           {% endfor %}
           vperm2f128  ymm{{col | times:4 | plus:row}}, ymm{{col | times:4 | plus:row}}, ymm{{col | times:4 | plus:row}}, 1
           {% for i in (0..3) %}
                vextractps  dword ptr [r{{col | plus: 8}}], xmm{{col | times:4|plus:row}}, {{i}}
                add         r{{col | plus: 8}}, rsi
           {% endfor %}
       {% endfor %}
    {% endfor %}
    jmp     {{L}}non_linear_loop

{{L}}store_f16:

    // Narrow all 12 accumulators to f16 in place; each xmm now holds 8 f16
    // values (imm 0: round to nearest even).
    {% for reg in (0..11) %}
        vcvtps2ph   xmm{{reg}}, ymm{{reg}}, 0
    {% endfor %}

    // Fast path requires contiguous f16 rows (row stride == 2 bytes).
    cmp         rsi, 2
	jne {{L}}store_generic_f16

    {% for col in (0..2) %}
        {% for row in (0..3) %}
            vmovups [r{{col|plus:8}} + {{row|times:16}}], xmm{{col|times:4|plus:row}}
        {% endfor %}
    {% endfor %}

    jmp     {{L}}non_linear_loop
    
{{L}}store_generic_f16:
    // Scalar f16 store: extract one 16-bit lane at a time (8 lanes per xmm,
    // 4 xmm per column), stepping each column pointer by the row stride.
    {% for col in (0..2) %}
        {% for vec in (0..3) %}
            {% for row in (0..7) %}
                pextrw  word ptr [r{{col|plus:8}}], xmm{{col|times:4|plus:vec}}, {{row}}
                add         r{{col|plus:8}}, rsi
            {% endfor %}
        {% endfor %}
    {% endfor %}

    jmp     {{L}}non_linear_loop

{% include "postamble.tmpliq" type:"f32", size:"32x3", suffix:suffix, G:G, L:L %}
